{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__This notebook__ trains an editable ResNet18 from scratch on the CIFAR10 dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=1\n",
      "editable_layer3_2019.09.19_23:06:14\n",
      "PyTorch version: 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# Replace YOURDEVICEHERE with the index of the GPU to use before running this cell.\n",
    "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n",
    "import os, sys, time\n",
    "import random\n",
    "\n",
    "# Make the parent directory importable so that `lib` can be found.\n",
    "sys.path.insert(0, '..')\n",
    "import lib\n",
    "\n",
    "import numpy as np\n",
    "import torch, torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# Seed every random number generator used in this notebook with the same value.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.random.manual_seed(SEED)\n",
    "\n",
    "from resnet import ResNet18\n",
    "\n",
    "# Use the GPU when one is visible, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "# Experiment name = fixed prefix + UTC timestamp (year.month.day_hour:minute:second).\n",
    "timestamp_fields = time.gmtime()[:6]\n",
    "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format('editable_layer3', *timestamp_fields)\n",
    "print(experiment_name)\n",
    "print(\"PyTorch version:\", torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchvision import transforms, datasets\n",
    "\n",
    "# Per-channel mean / std passed to Normalize for both the train and test pipelines.\n",
    "normalize_mean = (0.4914, 0.4822, 0.4465)\n",
    "normalize_std = (0.2023, 0.1994, 0.2010)\n",
    "\n",
    "# Training pipeline: random 32x32 crop (padding 4) and random horizontal flip, then tensor + normalize.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(normalize_mean, normalize_std),\n",
    "])\n",
    "\n",
    "# Test pipeline: no augmentation, only tensor conversion + normalize.\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(normalize_mean, normalize_std),\n",
    "])\n",
    "\n",
    "trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
    "testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
    "\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "# Materialize the whole test set as two tensors by concatenating every test batch.\n",
    "test_batches = list(testloader)\n",
    "X_test = torch.cat([images for images, _ in test_batches])\n",
    "y_test = torch.cat([labels for _, labels in test_batches])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The network is constructed first, then the optimizer object that is handed to lib.Editable.\n",
    "backbone = ResNet18()\n",
    "edit_optimizer = lib.IngraphRMSProp(\n",
    "    learning_rate=1e-3,\n",
    "    beta=nn.Parameter(torch.tensor(0.5, dtype=torch.float32)),\n",
    ")\n",
    "\n",
    "# Only the parameters of `layer3` are exposed through get_editable_parameters.\n",
    "model = lib.Editable(\n",
    "    module=backbone,\n",
    "    loss_function=lib.contrastive_cross_entropy,\n",
    "    get_editable_parameters=lambda module: module.layer3.parameters(),\n",
    "    optimizer=edit_optimizer,\n",
    "    max_steps=10,\n",
    ")\n",
    "model = model.to(device)\n",
    "\n",
    "trainer = lib.EditableTrainer(model, F.cross_entropy, experiment_name=experiment_name, max_norm=10)\n",
    "\n",
    "# Log the trainer's repr as text, with newlines converted to <br> tags.\n",
    "trainer_description = repr(trainer).replace('\\n', '<br>')\n",
    "trainer.writer.add_text(\"trainer\", trainer_description)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "af4871f821da4d80b68dd4841955e2ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=391), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import tqdm_notebook, tnrange\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Metrics measured before any training step are the baseline that later epochs must beat.\n",
    "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n",
    "min_error = val_metrics['base_error']\n",
    "min_drawdown = val_metrics['drawdown']\n",
    "early_stopping_epochs = 500\n",
    "number_of_epochs_without_improvement = 0\n",
    "\n",
    "\n",
    "def make_edit_generator():\n",
    "    \"\"\"Endlessly yield one training image on `device` together with a random integer label.\n",
    "\n",
    "    The label is drawn with torch.randint_like from [0, len(classes)) and has the\n",
    "    same shape as the true label of the sampled image.\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        single_sample_loader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=2)\n",
    "        for image, label in single_sample_loader:\n",
    "            image_on_device = image.to(device)\n",
    "            random_target = torch.randint_like(label, low=0, high=len(classes), device=device)\n",
    "            yield image_on_device, random_target\n",
    "\n",
    "\n",
    "edit_generator = make_edit_generator()\n",
    "\n",
    "while True:\n",
    "    # One pass over the training set; every step also consumes one sample from edit_generator.\n",
    "    for x_batch, y_batch in tqdm_notebook(trainloader):\n",
    "        inputs, targets = x_batch.to(device), y_batch.to(device)\n",
    "        trainer.step(inputs, targets, *next(edit_generator))\n",
    "\n",
    "    val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n",
    "    clear_output(True)\n",
    "\n",
    "    error_rate = val_metrics['base_error']\n",
    "    drawdown = val_metrics['drawdown']\n",
    "\n",
    "    # Each metric gets its own tagged checkpoint whenever it reaches a new minimum.\n",
    "    error_improved = error_rate < min_error\n",
    "    if error_improved:\n",
    "        trainer.save_checkpoint(tag='best_val_error')\n",
    "        min_error = error_rate\n",
    "\n",
    "    drawdown_improved = drawdown < min_drawdown\n",
    "    if drawdown_improved:\n",
    "        trainer.save_checkpoint(tag='best_drawdown')\n",
    "        min_drawdown = drawdown\n",
    "\n",
    "    # The patience counter restarts when either metric improved, otherwise it grows by one.\n",
    "    if error_improved or drawdown_improved:\n",
    "        number_of_epochs_without_improvement = 0\n",
    "    else:\n",
    "        number_of_epochs_without_improvement += 1\n",
    "\n",
    "    # An untagged checkpoint is written every epoch, then old temporary ones are removed.\n",
    "    trainer.save_checkpoint()\n",
    "    trainer.remove_old_temp_checkpoints()\n",
    "\n",
    "    if number_of_epochs_without_improvement > early_stopping_epochs:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lib import evaluate_quality\n",
    "\n",
    "# Fixed seed so the same 1000 test images and the same random labels are drawn on every run.\n",
    "np.random.seed(9)\n",
    "indices = np.random.permutation(len(X_test))[:1000]\n",
    "X_edit = X_test[indices].clone().to(device)\n",
    "# Random integer labels in [0, len(classes)), one per selected test image.\n",
    "y_edit = torch.tensor(np.random.randint(0, len(classes), size=y_test[indices].shape), device=device)\n",
    "\n",
    "# Evaluate `model`, the lib.Editable instance trained in this notebook.\n",
    "# (No variable named `editable_model` is ever defined, so referencing it raises NameError on a fresh kernel.)\n",
    "metrics = evaluate_quality(model, X_test, y_test, X_edit, y_edit, batch_size=512)\n",
    "for key in sorted(metrics.keys()):\n",
    "    print('{}\\t:{:.5}'.format(key, metrics[key]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
